{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/oughtinc/ergo/blob/notebooks-readme/notebooks/covid-19-inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notes\n",
    "\n",
    "* Switch to Italy\n",
    "* Graph data and results\n",
    "* Add variable for true initial infections, lockdown start date (11 March)\n",
    "* Add lag time (see Jacob/NYT models)\n",
    "* Add patient recovery"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install [Ergo](https://github.com/oughtinc/ergo) (our forecasting library) and a few tools we'll use in this colab:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --quiet poetry  # Fixes https://github.com/python-poetry/poetry/issues/532\n",
    "!pip install --quiet pendulum seaborn\n",
    "!pip install --quiet torchviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip uninstall --yes --quiet ergo\n",
    "!pip install  --quiet git+https://github.com/oughtinc/ergo.git@william\n",
    "# !pip install --upgrade --no-cache-dir --quiet git+https://github.com/oughtinc/ergo.git\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ergo\n",
    "\n",
    "# Handle for the confirmed-infection data. The same object is constructed\n",
    "# again in the Data section further down.\n",
    "confirmed_infections = ergo.data.covid19.ConfirmedInfections()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The extension is loaded from the google.colab package, so this cell is\n",
    "# expected to work only where that package is installed (i.e. in Colab).\n",
    "# NOTE(review): presumably switches DataFrame output to interactive tables - confirm.\n",
    "%load_ext google.colab.data_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import ergo\n",
    "import pendulum\n",
    "import pandas\n",
    "import seaborn\n",
    "\n",
    "from types import SimpleNamespace\n",
    "from typing import List\n",
    "from pendulum import DateTime\n",
    "from matplotlib import pyplot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Questions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are Metaculus ids for the questions we'll load, and some short names that will allow us to associate questions with variables in our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question_ids = [3704, 3712, 3713, 3711, 3722, 3761, 3705, 3706]  # 3740, 3736, \n",
    "question_names = [\n",
    "  # \"WHO Eastern Mediterranean Region on 2020/03/27\",\n",
    "  # \"WHO Region of the Americas on 2020/03/27\",\n",
    "  # \"WHO Western Pacific Region on 2020/03/27\",\n",
    "  # \"WHO South-East Asia Region on 2020/03/27\",\n",
    "  \"South Korea on 2020/03/27\",\n",
    "  # \"United Kingdom on 2020/03/27\",\n",
    "  # \"WHO African Region on 2020/03/27\",\n",
    "  # \"WHO European Region on 2020/03/27\",\n",
    "  # \"Bay Area on 2020/04/01\",\n",
    "  # \"San Francisco on 2020/04/01\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the question data from Metaculus:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "# Credentials come from the environment, falling back to an interactive prompt,\n",
    "# so no password has to be typed into (and saved with) the notebook source.\n",
    "metaculus_username = os.environ.get(\"METACULUS_USERNAME\", \"ought\")\n",
    "metaculus_password = os.environ.get(\"METACULUS_PASSWORD\") or getpass(\n",
    "  f\"Metaculus password for {metaculus_username}: \")\n",
    "metaculus = ergo.Metaculus(username=metaculus_username, password=metaculus_password)\n",
    "\n",
    "# Ids and names are paired by position; zip stops at the shorter list.\n",
    "questions = [metaculus.get_question(question_id, name=question_name)\n",
    "             for question_id, question_name in zip(question_ids, question_names)]\n",
    "ergo.MetaculusQuestion.to_dataframe(questions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our most important input is the data on confirmed cases (from Johns Hopkins):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Called by the model as confirmed_infections(area, date); the model treats a\n",
    "# KeyError from that call as \"no data for this date\".\n",
    "confirmed_infections = ergo.data.covid19.ConfirmedInfections()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Assumptions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assumptions are things that should be inferred from data but currently aren't:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty container; individual assumptions are attached as attributes in later cells.\n",
    "assumptions = SimpleNamespace()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lockdown start date per area. The model only switches doubling time for\n",
    "# areas that appear in this mapping.\n",
    "assumptions.lockdown_start = dict(\n",
    "  Italy=pendulum.datetime(2020, 3, 11),\n",
    "  Spain=pendulum.datetime(2020, 3, 15),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Main model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pyro\n",
    "\n",
    "Area = str  # areas are referred to by name\n",
    "\n",
    "@ergo.model\n",
    "def model(start: DateTime, end: DateTime, areas: List[Area], training=True):\n",
    "  \"\"\"\n",
    "  Exponential-growth model of confirmed infections with separate parameters per area.\n",
    "\n",
    "  The confirmed count on `start` seeds the prediction, which is multiplied by\n",
    "  2**(1 / doubling_time) for every following day before `end`. Areas listed in\n",
    "  assumptions.lockdown_start use a second doubling time from their lockdown\n",
    "  date onwards.\n",
    "\n",
    "    start - first day of the window; its confirmed count is the starting value\n",
    "    end - end of the window (the day loop stops before it)\n",
    "    areas - names of the areas to model\n",
    "    training - if True, days without confirmed data get no observation site;\n",
    "               if False, every day gets one\n",
    "  \"\"\"\n",
    "  for area in areas:\n",
    "    base_doubling = ergo.lognormal_from_interval(1., 14., name=f\"doubling_time {area}\")\n",
    "    # Upper end of the lockdown interval is the sampled pre-lockdown value, but at least 1.1.\n",
    "    lockdown_upper = torch.max(base_doubling, ergo.to_float(1.1))\n",
    "    lockdown_doubling = ergo.lognormal_from_interval(1., lockdown_upper, name=f\"doubling_time_lockdown {area}\")\n",
    "    noise_scale = ergo.halfnormal_from_interval(0.1, name=f\"observation_noise {area}\")\n",
    "    predicted = ergo.to_float(confirmed_infections(area, start))\n",
    "\n",
    "    for offset in range(1, (end - start).days):\n",
    "      day = start.add(days=offset)\n",
    "      day_label = day.format('YYYY/MM/DD')\n",
    "\n",
    "      # Look up the real count; a KeyError is treated as \"no data for this day\".\n",
    "      observed = None\n",
    "      try:\n",
    "        observed = ergo.to_float(confirmed_infections(area, day))\n",
    "        ergo.tag(observed, f\"actual {area} on {day_label}\")\n",
    "      except KeyError:\n",
    "        pass\n",
    "\n",
    "      in_lockdown = area in assumptions.lockdown_start and day >= assumptions.lockdown_start[area]\n",
    "      active_doubling = lockdown_doubling if in_lockdown else base_doubling\n",
    "      # Deliberately not `*=`: that would update the previously tagged tensor in place.\n",
    "      predicted = predicted * 2**(1. / active_doubling)\n",
    "      ergo.tag(predicted, f\"predicted {area} on {day_label}\")\n",
    "\n",
    "      # While training, days without data contribute no observation site.\n",
    "      if training and observed is None:\n",
    "        continue\n",
    "      ergo.normal(predicted,\n",
    "                  predicted*noise_scale,\n",
    "                  name=f\"predict_observed {area} on {day_label}\",\n",
    "                  obs=observed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Window and areas handed to the model; the model's day loop stops before end_date.\n",
    "areas = [\"Italy\", \"Spain\"]\n",
    "start_date, end_date = pendulum.datetime(2020, 3, 1), pendulum.datetime(2020, 4, 1)\n",
    "model_args = (start_date, end_date, areas)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.infer import Predictive\n",
    "import functools\n",
    "\n",
    "def infer_and_run(model, num_samples=5000, num_iterations=2000, \n",
    "                  debug=False, learning_rate=0.01, \n",
    "                  early_stopping_patience=200) -> pd.DataFrame:\n",
    "  \"\"\"\n",
    "  Fit an AutoNormal guide to `model` with SVI, then draw samples with Predictive.\n",
    "\n",
    "    model - callable accepting a `training` keyword; it is called with\n",
    "            training=True while fitting and training=False while sampling\n",
    "    num_samples - number of samples drawn after fitting\n",
    "    num_iterations - maximum number of optimizer iterations\n",
    "    debug - whether to output debug information (loss and guide quantiles)\n",
    "    learning_rate - Adam learning rate\n",
    "    early_stopping_patience - stop training if loss hasn't improved for this many iterations\n",
    "\n",
    "  Returns a DataFrame built from the dict of samples that Predictive returns.\n",
    "  \"\"\"\n",
    "  def to_numpy(d):\n",
    "    return {k: v.detach().numpy() for k, v in d.items()}\n",
    "\n",
    "  def debug_output(guide):\n",
    "    # Median with the 5%-95% range for every value the guide reports.\n",
    "    quantiles = to_numpy(guide.quantiles([0.05, 0.5, 0.95]))\n",
    "    for k, v in quantiles.items():\n",
    "      print(f\"{k}: {v[1]:.4f} [{v[0]:.4f}, {v[2]:.4f}]\")\n",
    "\n",
    "  guide = pyro.infer.autoguide.AutoNormal(model, \n",
    "                                          init_loc_fn=pyro.infer.autoguide.init_to_median)\n",
    "  pyro.clear_param_store()\n",
    "  # NOTE(review): presumably run once so the guide is initialised before the\n",
    "  # first debug_output call - confirm.\n",
    "  guide(training=True)\n",
    "  # Honour the caller's learning rate (it used to be fixed at 0.01 here).\n",
    "  adam = pyro.optim.Adam({\"lr\": learning_rate})\n",
    "  svi = SVI(model, guide, adam, loss=Trace_ELBO())\n",
    "\n",
    "  if debug:\n",
    "    debug_output(guide)\n",
    "    print()\n",
    "\n",
    "  loss = float(\"nan\")  # keeps the final report valid when num_iterations is 0\n",
    "  best_loss = None\n",
    "  last_improvement = 0\n",
    "\n",
    "  for j in range(num_iterations):\n",
    "    # calculate the loss and take a gradient step\n",
    "    loss = svi.step(training=True)\n",
    "    if best_loss is None or best_loss > loss:\n",
    "      best_loss = loss\n",
    "      last_improvement = j\n",
    "    if debug and j % 100 == 0:\n",
    "      print(\"[iteration %04d]\" % (j + 1))\n",
    "      print(f\"loss: {loss:.4f}\")\n",
    "      debug_output(guide)\n",
    "      print()\n",
    "    # Checked on every iteration so the patience is honoured exactly, not only\n",
    "    # at the 100-iteration logging points.\n",
    "    if j > (last_improvement + early_stopping_patience):\n",
    "      print(\"Stopping Early\")\n",
    "      break\n",
    "\n",
    "  print(f\"Final loss: {loss:.4f}\")\n",
    "  predictive = Predictive(model, guide=guide, num_samples=num_samples)\n",
    "  raw_samples = predictive(training=False)\n",
    "  return pd.DataFrame(to_numpy(raw_samples))\n",
    "\n",
    "\n",
    "\n",
    "# Fit with fewer iterations and samples than the function defaults, printing progress.\n",
    "samples = infer_and_run(\n",
    "  functools.partial(model, *model_args),\n",
    "  debug=True,\n",
    "  num_samples=1000,\n",
    "  num_iterations=1000,\n",
    ")\n",
    "samples.describe().T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "\n",
    "# Series to draw: (legend name, template of the sample column holding that series).\n",
    "to_plot = [\n",
    "  # (\"predict_observed\", \"predict_observed {area} on {date}\"),\n",
    "  (\"predicted\", \"predicted {area} on {date}\"),\n",
    "  (\"actual\", \"actual {area} on {date}\"), \n",
    "]\n",
    "\n",
    "high_quantile = 0.95\n",
    "low_quantile = 0.05\n",
    "\n",
    "# Day offsets from start_date; identical for every area and series.\n",
    "indices = list(range((end_date - start_date).days))\n",
    "\n",
    "for area in areas:\n",
    "  fig, ax = pyplot.subplots()\n",
    "  for name, template in to_plot:\n",
    "    highs = []\n",
    "    lows = []\n",
    "    means = []\n",
    "    for i in indices:\n",
    "      date = start_date.add(days=i)\n",
    "      datestr = date.format('YYYY/MM/DD')\n",
    "      key = template.format(area=area, date=datestr)\n",
    "      try:\n",
    "        column = samples[key]\n",
    "        means.append(column.mean())\n",
    "        highs.append(column.quantile(high_quantile))\n",
    "        lows.append(column.quantile(low_quantile))\n",
    "      except KeyError:\n",
    "        # No column for this day: leave a gap in the curve.\n",
    "        means.append(float(\"NaN\"))\n",
    "        highs.append(float(\"NaN\"))\n",
    "        lows.append(float(\"NaN\"))\n",
    "    # Band and line get distinct labels so the legend does not list each series twice.\n",
    "    ax.fill_between(indices, lows, highs, alpha=0.5,\n",
    "                    label=f\"{name} {low_quantile:.0%}-{high_quantile:.0%} range\")\n",
    "    ax.plot(indices, means, label=f\"{name} mean\")\n",
    "  ax.set_title(area)\n",
    "  ax.set_xlabel(f\"days since {start_date.format('YYYY/MM/DD')}\")\n",
    "  ax.set_ylabel(\"confirmed infections\")\n",
    "  ax.set_yscale(\"log\")\n",
    "  ax.legend()\n",
    "  pyplot.show()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
